from collections import deque
from itertools import combinations

import numpy
from qat.lang.AQASM.gates import RY
from qat.lang.AQASM.routines import QRoutine

from qaths.bugfix.routine_numpy_integer import bugfix_routine_indices_should_be_integer
from qaths.utils.endian import reverse_endian

# Amplitude vectors of the single-qubit basis states |0> and |1>.
# The builtin ``complex`` replaces the ``numpy.complex`` alias, which was
# deprecated in NumPy 1.20 and removed in NumPy 1.24.
_zero = numpy.array([1, 0], dtype=complex)
_one = numpy.array([0, 1], dtype=complex)


def _subvector_norms(data_vector: numpy.ndarray) -> numpy.ndarray:
    """
    Compute the splitting ratios of the binary tree built over ``data_vector``.

    The vector is recursively halved. For every node that gets split (a slice
    of the vector), the ratio ``||left half|| / ||node||`` is recorded, or 0
    when the node itself has norm 0. Ratios are listed level by level from the
    root, each level from left to right, so ``data_vector.size - 1`` values are
    returned.

    :param data_vector: a 1-D vector whose size is a power of 2.
    :return: the array of ratios described above.
    """
    length = data_vector.size
    assert length == (
        1 << (length.bit_length() - 1)
    ), "The given vector should have 2^n elements."

    def slice_norm(start, width):
        # Euclidean norm of data_vector[start:start + width].
        return numpy.linalg.norm(data_vector[start : start + width])

    ratios = []
    for level in range(1, length.bit_length()):
        node_width = length >> (level - 1)
        half_width = length >> level
        for node in range(1 << (level - 1)):
            node_norm = slice_norm(node * node_width, node_width)
            if node_norm:
                ratios.append(
                    slice_norm(2 * node * half_width, half_width) / node_norm
                )
            else:
                ratios.append(0)
    return numpy.array(ratios)


def _ctrl_rotations(gate, rotations):
    r"""
    Build a routine rotating a target qubit by ``rotations[state]``, where
    ``state`` is the basis state of the address qubits.

    The routine acts on ``log2(len(rotations)) + 1`` qubits. Qubits
    ``0 .. nb_address_qbits - 1`` are the address register, with qubit ``k``
    encoding bit ``k`` of ``state``; the last qubit is the target.

    The state-dependent rotation is decomposed by inclusion-exclusion over the
    subsets of address qubits: one uncontrolled gate, plus one gate controlled
    by each subset whose coefficient is non-zero. For a given state, the
    applied angles add up to ``rotations[state]``. This relies on
    ``gate(a) * gate(b) == gate(a + b)``, which holds for single-axis
    rotations such as ``RY``.

    Args:
        gate: a 1-qubit parametrized gate supporting ``.ctrl()``.
        rotations: the rotation angle for each address state; its length
            should be a power of 2. It is not modified.

    Returns:
        QRoutine: the routine described above.
    """

    def build_index(qbits):
        """Return the bit mask with bit ``q`` set for every qubit ``q`` in qbits."""
        index = 0
        for q in qbits:
            index |= 1 << q
        return index

    # Work on a private copy: the coefficients are computed in place below and
    # must not overwrite the caller's array.
    rotations = numpy.array(rotations, copy=True)
    nb_address_qbits = len(rotations).bit_length() - 1
    rout = QRoutine()
    # Force the arity of the routine (presumably so it spans every qubit even
    # when some of them receive no gate).
    rout.arity = nb_address_qbits + 1
    if rotations[0]:
        # Empty control set: rotation shared by every address state.
        rout.apply(gate(rotations[0]), nb_address_qbits)
    address_qbits = range(nb_address_qbits)
    for size in range(1, nb_address_qbits + 1):
        for qbits in combinations(address_qbits, size):
            index = build_index(qbits)
            # Remove the contributions of every proper subset of ``qbits``.
            # Smaller subsets were processed first, so their entries already
            # hold their own coefficients.
            rotations[index] -= rotations[0]
            for k in range(1, size):
                for inner in combinations(qbits, k):
                    rotations[index] -= rotations[build_index(inner)]
            if rotations[index]:
                ctrl_gate = gate(rotations[index])
                for _ in range(size):
                    ctrl_gate = ctrl_gate.ctrl()
                rout.apply(ctrl_gate, qbits, nb_address_qbits)
    return rout


def _leaf_angle(left, right, ratio):
    """
    Return the RY angle producing the signed amplitude pair of a tree leaf.

    ``RY(theta)|0> = cos(theta/2)|0> + sin(theta/2)|1>``. The returned
    ``theta`` gives amplitudes with the signs of ``(left, right)`` and the
    magnitudes ``(ratio, sqrt(1 - ratio**2))``. ``ratio`` is expected to be
    ``|left| / ||(left, right)||`` (0 for a zero pair).
    """
    if left >= 0 and right >= 0:
        return 2 * numpy.arccos(ratio)
    if left < 0 and right < 0:
        # Both amplitudes negated: shift theta / 2 by pi.
        return 2 * (numpy.arccos(ratio) + numpy.pi)
    # Mixed signs. With phi = arccos(sqrt(1 - ratio**2)), sin(phi) = ratio.
    # theta / 2 = phi + pi/2 negates the left amplitude; phi - pi/2 negates
    # the right one.
    if left < 0:
        return numpy.pi + 2 * numpy.arccos(numpy.sqrt(1 - ratio ** 2))
    return -numpy.pi + 2 * numpy.arccos(numpy.sqrt(1 - ratio ** 2))


def QRAM(matrix: numpy.ndarray, msb_first: bool = True) -> QRoutine:
    r"""
    The QRAM oracle performs a unitary operator of the form:

    :math:`|i\rangle |0\rangle \mapsto |i\rangle |x(i)\rangle`

    where :math:`x(i) \in \mathbb{R}^d` is a real vector, loaded as the
    normalised amplitudes :math:`x(i) / \lVert x(i) \rVert`. With
    ``msb_first=False``, an all-zero row loads the basis state of the last
    column.

    Args:
        matrix (np.ndarray, np.matrix, list<list<float>>): a matrix whose rows
            are the :math:`x(i)`. Its number of rows must be a power of 2. Its
            number of columns must be at least 2; it is zero-padded up to the
            next power of 2 when needed.
        msb_first (bool, optional): the bit order to be used. If True (the
            default), the matrix is first transformed with ``reverse_endian``,
            presumably yielding an msb-first convention. If False, qubit ``k``
            of each register encodes bit ``k`` of the row / column index
            (lsb-first).

    Returns:
        QRoutine: a routine whose first ``log2(rows)`` qubits are the address
        register and whose remaining ``log2(padded columns)`` qubits are the
        data register.

    Raises:
        ValueError: if the number of rows is not a power of 2, or if the
            matrix has fewer than 2 columns.
    """
    matrix = numpy.asarray(matrix)
    nb_rows, nb_columns = matrix.shape[0], matrix.shape[1]
    if nb_rows < 1 or nb_rows & (nb_rows - 1):
        raise ValueError("The number of rows must be a power of 2")
    if nb_columns < 2:
        raise ValueError("The matrix must have at least 2 columns")
    if nb_columns & (nb_columns - 1):
        # Zero-pad up to the next power of 2. Integer bit arithmetic avoids
        # shifting by a float, and keeping the input dtype avoids truncating
        # non-integer entries.
        padded = numpy.zeros(
            (nb_rows, 1 << (nb_columns - 1).bit_length()), dtype=matrix.dtype
        )
        padded[:, :nb_columns] = matrix
        matrix = padded

    if msb_first:
        matrix = reverse_endian(matrix)

    nb_address_qbits = matrix.shape[0].bit_length() - 1
    nb_data_qbits = matrix.shape[1].bit_length() - 1
    rout = QRoutine(arity=nb_address_qbits + nb_data_qbits)

    # One row of angles per address. Each row is a binary tree over the
    # (padded) columns with shape[1] - 1 nodes, stored level by level from the
    # root; the last shape[1] // 2 nodes are the leaves.
    nb_nodes = matrix.shape[1] - 1
    leaf_start = (1 << (nb_data_qbits - 1)) - 1
    angles = numpy.empty((matrix.shape[0], nb_nodes))
    for i, row in enumerate(matrix):
        ratios = _subvector_norms(row)
        # Internal nodes only split the norm between the two halves.
        angles[i, :leaf_start] = 2 * numpy.arccos(ratios[:leaf_start])
        # Leaves also encode the signs of the two amplitudes they produce.
        for node in range(leaf_start, nb_nodes):
            j = 2 * (node - leaf_start)
            angles[i, node] = _leaf_angle(row[j], row[j + 1], ratios[node])

    # Prepare the data qubits from the last one (tree root) down to the first.
    # Each one is rotated conditionally on the address register and on the
    # data qubits already prepared. The latter are pushed to the front of the
    # control list, so they select the node within the current tree level:
    # the low bits of the flattened, row-major ``level_angles`` index.
    controls = deque(range(nb_address_qbits))
    nb_rotations = 1
    data_qbits = range(nb_address_qbits, nb_address_qbits + nb_data_qbits)
    for target in reversed(data_qbits):
        level_angles = angles[:, (nb_rotations - 1) : (2 * nb_rotations - 1)]
        rout.apply(_ctrl_rotations(RY, level_angles.flatten()), controls, target)
        nb_rotations <<= 1
        controls.appendleft(target)
    return rout


def PermutationOnlyQRAM(matrix: numpy.ndarray) -> QRoutine:
    r"""
    Build a QRAM loading, for each row of ``matrix``, the positions of its
    non-zero entries.

    Entries whose absolute value exceeds 0.5 become 1 and all others 0. The
    resulting 0/1 matrix is loaded with :func:`QRAM` (``msb_first=True``), so
    a row with a single non-zero entry in column ``j`` loads the data state of
    column ``j``. Rows without any such entry get a 1 on the diagonal, i.e.
    they load the state of their own row index.

    :param matrix: a (2^n, 2^m) array-like, presumably with integer entries
        (see the 0.5 threshold). When some row is empty, the diagonal fallback
        needs at least as many columns as rows.
    :return: a routine on n + m qubits: the n row qubits, then the m column
        qubits.
    """
    matrix = numpy.asarray(matrix)
    n = numpy.round(numpy.log2(matrix.shape[0])).astype(int)
    m = numpy.round(numpy.log2(matrix.shape[1])).astype(int)

    # 0.5 is used here because the first positive non-zero integer is greater than 0.5
    # Any value in (0, 1) would be fine.
    # The builtin ``int`` replaces the ``numpy.int`` alias removed in NumPy 1.24.
    permutation = (numpy.abs(matrix) > 0.5).astype(int)

    # Handle the case where the given matrix has no non-zero entry in some rows
    # by inserting ones on the diagonal.
    nnz_per_row = permutation.sum(axis=1)
    for i in numpy.flatnonzero(nnz_per_row == 0):
        permutation[i, i] = 1

    rout = QRoutine(arity=n + m)
    x = list(range(n))
    perm = list(range(n, n + m))
    rout.apply(QRAM(permutation, msb_first=True), x, perm)

    return rout


def SignOnlyQRAM(matrix: numpy.ndarray) -> QRoutine:
    r"""
    Build a QRAM loading, for each row of ``matrix``, the sign of its first
    non-zero entry into a single qubit.

    A row whose first non-zero entry is negative loads the amplitude vector
    ``[0, 1]`` (``_one``); every other row loads ``[1, 0]`` (``_zero``). Both
    are loaded with :func:`QRAM` (``msb_first=True``). Rows with no non-zero
    entry are therefore arbitrarily marked as positive. For rows with a single
    non-zero entry (e.g. signed permutation matrices), this is the sign of the
    row.

    :param matrix: a 2^n-row array-like.
    :return: a routine on n + 1 qubits: the n row qubits, then the sign qubit.
    """
    matrix = numpy.asarray(matrix)
    n = numpy.round(numpy.log2(matrix.shape[0])).astype(int)

    # Only the first non-zero entry is inspected, like UnsignedIntOnlyQRAM
    # does. Comparing the whole array of non-zero entries in a boolean context
    # raised "truth value ... is ambiguous" for rows with several of them.
    is_negative = []
    for row in matrix:
        nnz = row.nonzero()[0]
        is_negative.append(nnz.size > 0 and row[nnz[0]] < 0)
    binary_signs = numpy.array([_one if neg else _zero for neg in is_negative])

    rout = QRoutine(arity=n + 1)
    x = list(range(n))
    sign = bugfix_routine_indices_should_be_integer(n)
    rout.apply(QRAM(binary_signs, msb_first=True), x, sign)

    return rout


def UnsignedIntOnlyQRAM(matrix: numpy.ndarray, int_size: int) -> QRoutine:
    """
    Build a QRAM loading, for each row of ``matrix``, the absolute value of its
    first non-zero entry as an unsigned integer on ``int_size`` qubits.

    Each value ``v`` is one-hot encoded as the basis vector of index ``int(v)``
    among ``2 ** int_size`` columns, then loaded with :func:`QRAM`
    (``msb_first=True``). Rows with no non-zero entry load the value 0.
    Non-integer values are truncated by ``int``.

    :param matrix: a 2^n-row array-like.
    :param int_size: number of qubits of the integer register; must be >= 1.
    :return: a routine on n + int_size qubits: the n row qubits, then the
        integer register.
    :raises RuntimeError: if ``int_size`` is smaller than 1.
    :raises OverflowError: if some value does not fit on ``int_size`` bits.
    """

    if int_size < 1:
        raise RuntimeError("The int_size should be 1 or more.")

    matrix = numpy.asarray(matrix)
    n = numpy.round(numpy.log2(matrix.shape[0])).astype(int)

    def first_nonzero(row):
        # First non-zero entry of ``row``, or 0 when there is none.
        nnz = row.nonzero()[0]
        return row[nnz[0]] if nnz.size > 0 else 0

    values = numpy.abs(numpy.array([first_nonzero(matrix[r]) for r in range(2 ** n)]))

    if numpy.any(values >= 2 ** int_size):
        raise OverflowError(
            "The given int_size ({}) is not enough to store the values {}".format(
                int_size, values[values >= 2 ** int_size]
            )
        )
    # The builtin ``float`` replaces the ``numpy.float`` alias removed in
    # NumPy 1.24.
    binary_values = numpy.zeros((2 ** n, 2 ** int_size), dtype=float)
    for i, val in enumerate(values):
        binary_values[i, int(val)] = 1.0

    rout = QRoutine(arity=n + int_size)

    x = list(range(n))
    w = list(range(n, n + int_size))

    rout.apply(QRAM(binary_values, msb_first=True), x, w)
    return rout


def UnsignedIntQRAM(matrix: numpy.ndarray, int_size: int) -> QRoutine:
    """
    Load, for each row, the non-zero positions and an unsigned integer value.

    PermutationOnlyQRAM and then UnsignedIntOnlyQRAM are applied, both
    controlled by the same row register.

    :param matrix: a (2^n, 2^m) matrix.
    :param int_size: number of qubits of the integer register.
    :return: a routine on n + m + int_size qubits laid out as
        [row (n) | column (m) | value (int_size)].
    """
    nb_row_qbits = numpy.round(numpy.log2(matrix.shape[0])).astype(int)
    nb_col_qbits = numpy.round(numpy.log2(matrix.shape[1])).astype(int)
    value_start = nb_row_qbits + nb_col_qbits

    routine = QRoutine(arity=value_start + int_size)
    row_reg = list(range(nb_row_qbits))
    col_reg = list(range(nb_row_qbits, value_start))
    value_reg = list(range(value_start, value_start + int_size))

    routine.apply(PermutationOnlyQRAM(matrix), row_reg, col_reg)
    routine.apply(UnsignedIntOnlyQRAM(matrix, int_size), row_reg, value_reg)

    return routine


def SignedIntQRAM(matrix: numpy.ndarray, int_size: int) -> QRoutine:
    """
    Load, for each row, the non-zero positions, the magnitude and the sign.

    PermutationOnlyQRAM, UnsignedIntOnlyQRAM and SignOnlyQRAM are applied in
    that order, all controlled by the same row register.

    :param matrix: a (2^n, 2^m) matrix.
    :param int_size: number of qubits of the magnitude register.
    :return: a routine on n + m + int_size + 1 qubits laid out as
        [row (n) | column (m) | magnitude (int_size) | sign (1)].
    """
    nb_row_qbits = numpy.round(numpy.log2(matrix.shape[0])).astype(int)
    nb_col_qbits = numpy.round(numpy.log2(matrix.shape[1])).astype(int)
    magnitude_start = nb_row_qbits + nb_col_qbits
    sign_position = magnitude_start + int_size

    routine = QRoutine(arity=sign_position + 1)
    row_reg = list(range(nb_row_qbits))
    col_reg = list(range(nb_row_qbits, magnitude_start))
    magnitude_reg = list(range(magnitude_start, sign_position))
    sign_qbit = bugfix_routine_indices_should_be_integer(sign_position)

    routine.apply(PermutationOnlyQRAM(matrix), row_reg, col_reg)
    routine.apply(UnsignedIntOnlyQRAM(matrix, int_size), row_reg, magnitude_reg)
    routine.apply(SignOnlyQRAM(matrix), row_reg, sign_qbit)

    return routine


def PermutationSignQRAM(matrix: numpy.ndarray) -> QRoutine:
    """
    Load, for each row of ``matrix``, the non-zero positions and the sign.

    PermutationOnlyQRAM and then SignOnlyQRAM are applied, both controlled by
    the same row register.

    :param matrix: a (2^n, 2^m) matrix, e.g. a signed permutation matrix
        (m == n).
    :return: a routine on n + m + 1 qubits laid out as
        [row (n) | column (m) | sign (1)].
    """
    n = numpy.round(numpy.log2(matrix.shape[0])).astype(int)
    # Size the column register from the number of columns, as
    # PermutationOnlyQRAM does, instead of assuming a square matrix. For square
    # matrices (m == n) the layout is unchanged.
    m = numpy.round(numpy.log2(matrix.shape[1])).astype(int)

    rout = QRoutine(arity=n + m + 1)
    x = list(range(n))
    perm = list(range(n, n + m))
    s = bugfix_routine_indices_should_be_integer(n + m)

    rout.apply(PermutationOnlyQRAM(matrix), x, perm)
    rout.apply(SignOnlyQRAM(matrix), x, s)

    return rout


def OnesOrZerosSignQRAM(matrix: numpy.ndarray) -> QRoutine:
    """
    Load, for each row, the non-zero positions, a one-qubit magnitude and the sign.

    These sub-routines are applied in order, all controlled by the same row
    register:
    1. PermutationOnlyQRAM.
    2. UnsignedIntOnlyQRAM with ``int_size=1``, which raises OverflowError for
       magnitudes of 2 or more.
    3. SignOnlyQRAM.

    :param matrix: a (2^n, 2^m) matrix, presumably with entries in {-1, 0, 1}.
    :return: a routine on n + m + 2 qubits laid out as
        [row (n) | column (m) | magnitude (1) | sign (1)].
    """
    nb_row_qbits = numpy.round(numpy.log2(matrix.shape[0])).astype(int)
    nb_col_qbits = numpy.round(numpy.log2(matrix.shape[1])).astype(int)
    offset = nb_row_qbits + nb_col_qbits

    routine = QRoutine(arity=offset + 2)
    row_reg = list(range(nb_row_qbits))
    col_reg = list(range(nb_row_qbits, offset))
    magnitude_qbit = bugfix_routine_indices_should_be_integer(offset)
    sign_qbit = bugfix_routine_indices_should_be_integer(offset + 1)

    steps = (
        (PermutationOnlyQRAM, (matrix,), col_reg),
        (UnsignedIntOnlyQRAM, (matrix, 1), magnitude_qbit),
        (SignOnlyQRAM, (matrix,), sign_qbit),
    )
    for build, args, targets in steps:
        routine.apply(build(*args), row_reg, targets)

    return routine
